package drds.plus.rule_engine.rule_calculate;

import drds.plus.common.model.comparative.ColumnNameToComparativeMapChoicer;
import drds.plus.common.model.comparative.Comparative;
import drds.plus.rule_engine.rule.Rule;
import drds.plus.rule_engine.rule.RuleColumn;
import drds.plus.rule_engine.table_rule.ITableRule;
import drds.plus.rule_engine.utils.RuleUtils;
import drds.plus.rule_engine.utils.sample.Samples;

import java.util.*;


/**
 * Routes a logical table to its physical targets.
 *
 * <p>The table's sharding {@link Rule} is matched against the column conditions
 * ({@link Comparative}s) extracted from a SQL statement. The rule is then evaluated to
 * obtain, for each data node, the table names and the column values that lead to them.
 */
public class RuleCalculate {


    /**
     * Resolves the comparatives for {@code rule}'s columns and enforces the
     * full-table-scan policy when the rule cannot be matched.
     *
     * @param columnNameToComparativeMap        cache of already-resolved comparatives (column name -> comparative);
     *                                          newly resolved non-null comparatives are added to it
     * @param matchColumnNameToComparativeMap   output map; cleared and filled with the comparatives of the
     *                                          matched rule's columns
     * @param columnNameToComparativeMapChoicer source used to resolve a column's comparative from {@code args}
     * @param args                              bind arguments of the statement
     * @param tableRule                         owning table rule, consulted for {@code isAllowFullTableScan()}
     * @param rule                              candidate rule; may be {@code null}
     * @param forceAllowFullTableScan           when {@code true}, an unmatched rule is tolerated regardless of
     *                                          the table rule's own setting
     * @return {@code rule} when it matched, otherwise {@code null}
     * @throws IllegalArgumentException if {@code rule} is non-null, does not match, and full table
     *                                  scans are allowed neither by {@code tableRule} nor by
     *                                  {@code forceAllowFullTableScan}
     */
    private static <T> Rule<T> fillColumnNameToComparativeMap(Map<String, Comparative> columnNameToComparativeMap,
                                                              Map<String, Comparative> matchColumnNameToComparativeMap,
                                                              ColumnNameToComparativeMapChoicer columnNameToComparativeMapChoicer,
                                                              List<Object> args,
                                                              ITableRule<String, String> tableRule,
                                                              Rule<T> rule,
                                                              boolean forceAllowFullTableScan) {
        if (rule == null) {
            return null;
        }
        Rule<T> matchedRule = fillColumnNameToComparativeMap(rule, columnNameToComparativeMapChoicer, columnNameToComparativeMap, matchColumnNameToComparativeMap, args);
        // A sharding rule exists but could not be matched: only proceed if a full table scan is permitted.
        if (matchedRule == null && !(tableRule.isAllowFullTableScan() || forceAllowFullTableScan)) {
            Set<String> columnNameSet = new LinkedHashSet<String>();
            for (RuleColumn ruleColumn : rule.getRuleColumnSet()) {
                columnNameSet.add(ruleColumn.columnName);
            }
            // Wrapped in a one-element list so the message keeps its established "[[a, b]]" format.
            List<Set<String>> shardColumns = Collections.singletonList(columnNameSet);
            throw new IllegalArgumentException("sql_process contain no sharding column:" + shardColumns + " , maybe you can set allowFullTableScan");
        }
        return matchedRule;
    }

    /**
     * Tries to match {@code rule}'s columns against the statement's conditions.
     *
     * <p>Three passes are tried in order:
     * <ol>
     *   <li>every rule column (mandatory and optional) has a condition;</li>
     *   <li>every mandatory column has a condition (optional columns are ignored);</li>
     *   <li>only for rules without any mandatory column: at least one optional column has a
     *       condition.</li>
     * </ol>
     *
     * @return {@code rule} on success, with {@code matchColumnNameToComparativeMap} holding the
     *         matched comparatives; otherwise {@code null}, with that map left empty
     */
    private static <T> Rule<T> fillColumnNameToComparativeMap(Rule<T> rule, ColumnNameToComparativeMapChoicer columnNameToComparativeMapChoicer, Map<String, Comparative> columnNameToComparativeMap, Map<String, Comparative> matchColumnNameToComparativeMap, List<Object> args) {
        // Pass 1: all rule columns, mandatory and optional, have a condition -> exact match.
        matchColumnNameToComparativeMap.clear();
        for (RuleColumn ruleColumn : rule.getColumnNameToRuleColumnMap().values()) {
            Comparative comparative = getComparative(ruleColumn.columnName, columnNameToComparativeMap, columnNameToComparativeMapChoicer, args);
            if (comparative == null) {
                break;
            }
            matchColumnNameToComparativeMap.put(ruleColumn.columnName, comparative);
        }
        if (matchColumnNameToComparativeMap.size() == rule.getColumnNameToRuleColumnMap().size()) {
            return rule;
        }

        // Pass 2: all mandatory columns have a condition; optional columns are skipped.
        matchColumnNameToComparativeMap.clear();
        int mandatoryColumnCount = 0;
        for (RuleColumn ruleColumn : rule.getColumnNameToRuleColumnMap().values()) {
            if (ruleColumn.optional) {
                continue;
            }
            mandatoryColumnCount++;
            Comparative comparative = getComparative(ruleColumn.columnName, columnNameToComparativeMap, columnNameToComparativeMapChoicer, args);
            if (comparative == null) {
                break;
            }
            matchColumnNameToComparativeMap.put(ruleColumn.columnName, comparative);
        }
        // The loop only breaks after counting a mandatory column.
        // Therefore mandatoryColumnCount > 0 exactly when the rule has a mandatory column.
        if (mandatoryColumnCount != 0) {
            if (matchColumnNameToComparativeMap.size() == mandatoryColumnCount) {
                return rule;
            }
            // A mandatory column is unconstrained. A partial optional-column match must not
            // rescue this rule, so it is rejected here rather than falling through to pass 3.
            matchColumnNameToComparativeMap.clear();
            return null;
        }

        // Pass 3: the rule has only optional columns (e.g. rule=..#a?#..#b?#..) and the SQL
        // constrains some of them.
        matchColumnNameToComparativeMap.clear();
        for (RuleColumn ruleColumn : rule.getColumnNameToRuleColumnMap().values()) {
            Comparative comparative = getComparative(ruleColumn.columnName, columnNameToComparativeMap, columnNameToComparativeMapChoicer, args);
            if (comparative != null) {
                matchColumnNameToComparativeMap.put(ruleColumn.columnName, comparative);
            }
        }
        return matchColumnNameToComparativeMap.isEmpty() ? null : rule;
    }


    /**
     * Returns the comparative for {@code columnName}.
     *
     * <p>The cache is consulted first. On a miss, the choicer is asked and a non-null result is
     * cached. Misses are not cached, so a later call asks the choicer again.
     *
     * @return the column's comparative, or {@code null} if the statement does not constrain it
     */
    private static Comparative getComparative(String columnName, Map<String, Comparative> columnNameToComparativeMap, ColumnNameToComparativeMapChoicer columnNameToComparativeMapChoicer, List<Object> args) {
        Comparative comparative = columnNameToComparativeMap.get(columnName);
        if (comparative == null) {
            comparative = columnNameToComparativeMapChoicer.getColumnComparative(columnName, args);
            if (comparative != null) {
                columnNameToComparativeMap.put(columnName, comparative);
            }
        }
        return comparative;
    }

    /**
     * Computes the routing result for {@code tableRule}.
     *
     * @param needSourceKey currently ignored; kept for API compatibility
     * @throws IllegalArgumentException if the rule cannot be matched and full table scans are not allowed
     * @throws IllegalStateException    if no usable rule is available (see the private overload)
     */
    public RuleCalculateResult ruleCalculate(ColumnNameToComparativeMapChoicer columnNameToComparativeMapChoicer, List<Object> args, ITableRule<String, String> tableRule, boolean needSourceKey, boolean forceAllowFullTableScan) {
        return ruleCalculate(columnNameToComparativeMapChoicer, args, tableRule, forceAllowFullTableScan);
    }


    /**
     * Matches the table's rule, evaluates it once, and attaches the resulting
     * table -> column-values mapping to every data node of the actual topology.
     */
    private RuleCalculateResult ruleCalculate(ColumnNameToComparativeMapChoicer columnNameToComparativeMapChoicer, List<Object> args, ITableRule<String, String> tableRule, boolean forceAllowFullTableScan) {
        Map<String, Comparative> allColumnNameToComparativeMap = new HashMap<String, Comparative>(2);
        // Column name -> comparative for the columns of the matched rule.
        Map<String, Comparative> matchColumnNameToComparativeMap = new HashMap<String, Comparative>(2);
        Object outerContext = tableRule.getOuterContext();
        Rule<String> rule = fillColumnNameToComparativeMap(allColumnNameToComparativeMap, matchColumnNameToComparativeMap, columnNameToComparativeMapChoicer, args, tableRule, tableRule.getRule(), forceAllowFullTableScan);
        if (rule == null) {
            // Reached when the table has no rule, or when no rule matched but a full table scan
            // was allowed. NOTE(review): a full-scan routing fallback is not implemented here.
            // Previously this surfaced as a NullPointerException on rule.calculate(...).
            throw new IllegalStateException("no sharding rule matched for table rule, full table scan routing is not supported");
        }

        Map<String, Samples> stringSamplesMap = RuleUtils.cast(rule.calculate(matchColumnNameToComparativeMap, outerContext));
        Map<String, Map<String, ColumnNameToValueSetMap>> topology = new HashMap<String, Map<String, ColumnNameToValueSetMap>>(tableRule.getActualTopology().size());
        for (String dbValue : tableRule.getActualTopology().keySet()) {
            // Each data node gets its own freshly built map instance.
            topology.put(dbValue, toMapField(stringSamplesMap));
        }

        return new RuleCalculateResult(buildTargetDbListWithSourceKey(topology), matchColumnNameToComparativeMap);
    }


    /**
     * Converts rule results (rule calculation result -> samples that produced it) into
     * result -> {@link ColumnNameToValueSetMap} built from each sample's column enumerates.
     */
    private Map<String, ColumnNameToValueSetMap> toMapField(Map<String/* rule calculation result */, Samples/* samples that produced the result */> stringSamplesMap) {
        Map<String, ColumnNameToValueSetMap> res = new HashMap<String, ColumnNameToValueSetMap>(stringSamplesMap.size());
        for (Map.Entry<String, Samples> e : stringSamplesMap.entrySet()) {
            ColumnNameToValueSetMap f = new ColumnNameToValueSetMap(e.getValue().size());
            f.setColumnNameToValueSetMap(e.getValue().getColumnEnumerates());
            res.put(e.getKey(), f);
        }
        return res;
    }


    /**
     * Flattens data node id -> (table -> column values) into one {@link DataNodeDataScatterInfo}
     * per data node.
     */
    private List<DataNodeDataScatterInfo> buildTargetDbListWithSourceKey(Map<String, Map<String, ColumnNameToValueSetMap>> stringMapMap) {
        List<DataNodeDataScatterInfo> dataNodeDataScatterInfoList = new ArrayList<DataNodeDataScatterInfo>(stringMapMap.size());
        for (Map.Entry<String, Map<String, ColumnNameToValueSetMap>> entry : stringMapMap.entrySet()) {
            DataNodeDataScatterInfo dataNodeDataScatterInfo = new DataNodeDataScatterInfo();
            dataNodeDataScatterInfo.setDataNodeId(entry.getKey());
            dataNodeDataScatterInfo.setTableNameToColumnNameToValueSetMapMap(entry.getValue());
            dataNodeDataScatterInfoList.add(dataNodeDataScatterInfo);
        }
        return dataNodeDataScatterInfoList;
    }

}
